{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda3/lib/python3.6/site-packages/pandas/core/generic.py:5430: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame\n",
      "\n",
      "See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n",
      "  self._update_inplace(new_data)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 550068 entries, 0 to 550067\n",
      "Data columns (total 19 columns):\n",
      "Occupation                       550068 non-null int64\n",
      "Product_Category_1               550068 non-null int64\n",
      "Product_Category_2               550068 non-null float64\n",
      "Product_Category_3               550068 non-null float64\n",
      "Purchase                         550068 non-null int64\n",
      "Gender_M                         550068 non-null uint8\n",
      "Age_18-25                        550068 non-null uint8\n",
      "Age_26-35                        550068 non-null uint8\n",
      "Age_36-45                        550068 non-null uint8\n",
      "Age_46-50                        550068 non-null uint8\n",
      "Age_51-55                        550068 non-null uint8\n",
      "Age_55+                          550068 non-null uint8\n",
      "City_Category_B                  550068 non-null uint8\n",
      "City_Category_C                  550068 non-null uint8\n",
      "Stay_In_Current_City_Years_1     550068 non-null uint8\n",
      "Stay_In_Current_City_Years_2     550068 non-null uint8\n",
      "Stay_In_Current_City_Years_3     550068 non-null uint8\n",
      "Stay_In_Current_City_Years_4+    550068 non-null uint8\n",
      "Marital_Status_1                 550068 non-null uint8\n",
      "dtypes: float64(2), int64(3), uint8(14)\n",
      "memory usage: 28.3 MB\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Load the raw purchase records; train.csv is read from the working directory.\n",
    "df = pd.read_csv('train.csv')\n",
    "\n",
    "# Columns carried into the model: customer attributes, product categories\n",
    "# and the Purchase target.\n",
    "selected_columns = ['Gender', 'Age', 'Occupation', 'City_Category',\n",
    "                    'Stay_In_Current_City_Years', 'Marital_Status',\n",
    "                    'Product_Category_1', 'Product_Category_2',\n",
    "                    'Product_Category_3', 'Purchase']\n",
    "\n",
    "# Replace missing secondary/tertiary product categories with 0.\n",
    "# fillna returns a new frame here, so nothing is written into a slice of df;\n",
    "# the previous in-place fill on columns of a sliced frame raised\n",
    "# SettingWithCopyWarning.\n",
    "df1 = df[selected_columns].fillna({'Product_Category_2': 0,\n",
    "                                   'Product_Category_3': 0})\n",
    "\n",
    "# Categorical columns to one-hot encode.\n",
    "object_names = ['Gender', 'Age',\n",
    "                'City_Category',\n",
    "                'Stay_In_Current_City_Years',\n",
    "                'Marital_Status']\n",
    "\n",
    "# drop_first=True keeps k-1 indicator columns for a categorical with k levels.\n",
    "dataset = pd.get_dummies(df1,\n",
    "                         columns=object_names,\n",
    "                         drop_first=True)\n",
    "\n",
    "dataset.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work on a 20% random sample of the encoded data. random_state pins the\n",
    "# draw so the same rows are selected on every run (7 is the seed used by the\n",
    "# model cells in this notebook).\n",
    "sample_df = dataset.sample(frac=0.20, random_state=7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(110014, 19)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: (rows, columns) of the sampled frame.\n",
    "sample_df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(550068, 19)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (rows, columns) of the full encoded frame, for comparison with the sample.\n",
    "dataset.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "23961"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Largest purchase amount in the full data set.\n",
    "dataset['Purchase'].max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Occupation                       False\n",
       "Product_Category_1               False\n",
       "Product_Category_2               False\n",
       "Product_Category_3               False\n",
       "Purchase                         False\n",
       "Gender_M                         False\n",
       "Age_18-25                        False\n",
       "Age_26-35                        False\n",
       "Age_36-45                        False\n",
       "Age_46-50                        False\n",
       "Age_51-55                        False\n",
       "Age_55+                          False\n",
       "City_Category_B                  False\n",
       "City_Category_C                  False\n",
       "Stay_In_Current_City_Years_1     False\n",
       "Stay_In_Current_City_Years_2     False\n",
       "Stay_In_Current_City_Years_3     False\n",
       "Stay_In_Current_City_Years_4+    False\n",
       "Marital_Status_1                 False\n",
       "dtype: bool"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-column flag: True if the column still contains any missing value.\n",
    "dataset.isnull().any()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.axes._subplots.AxesSubplot at 0x11cb53790>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAD8CAYAAAC7IukgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAFypJREFUeJzt3X+M3PWd3/Hnu/bB+bgQTGhWyEa101h3dfC1R1bE11TRKlRg4HTmJJCMUDE5S9ZFcJerXDWm+YNTEiRoy9HQJlRu7MZEKA7H5YR1mPoswiiqFMyPwGEcjngDLmxw4VIbDidK6Kbv/jGfJV+W2V0z89mdWc/zIY32O+/v5/ud73u+O/va73e+sxuZiSRJNf2Dfm+AJOn0Y7hIkqozXCRJ1RkukqTqDBdJUnWGiySpOsNFklSd4SJJqs5wkSRVt7TfG1Dbeeedl6tWrepq2Z/85CecddZZdTdoERnm/oe5dxju/oe5d/hl/08++eSPM/Mf1lrvaRcuq1at4oknnuhq2VarxdjYWN0NWkSGuf9h7h2Gu/9h7h1+2X9E/K+a6/W0mCSpOsNFklSd4SJJqs5wkSRVZ7hIkqozXCRJ1RkukqTqDBdJUnWGiySputPuE/p6b1Ztf/Dt6W3rJrmhcX8+Hb3tygV5HEn94ZGLJKk6w0WSVJ3hIkmqznCRJFVnuEiSqjNcJEnVGS6SpOoMF0lSdYaLJKk6w0WSVJ3hIkmqznCRJFVnuEiSqjNcJEnVzRkuEbErIl6LiGcbtf8QEX8bEc9ExF9GxDmNeTdHxHhEPB8RlzXqG0ptPCK2N+qrI+JgRByJiG9GxBmlfma5P17mr6rVtCRpfp3KkcvXgA3TageACzPzt4AfADcDRMRaYBPwkbLMVyJiSUQsAb4MXA6sBa4tYwFuB+7MzDXACWBLqW8BTmTmh4E7yzhJ0iIwZ7hk5neA49Nqf52Zk+Xuo8DKMr0R2JOZP8/MF4Fx4OJyG8/MFzLzLWAPsDEiAvgkcH9ZfjdwVWNdu8v0/cAlZbwkacDVeM/lD4CHyvQK4OXGvIlSm6n+AeD1RlBN1d+xrjL/jTJekjTgevo3xxHxOWASuHeq1GFY0jnEcpbxs62r03ZsBbYCjIyM0Gq1Zt7oWZw8ebLrZRerbesm354eWfbO+/Np0J7nYdz3TcPc/zD3DvPXf9fhEhGbgd8FLsnMqR/6E8AFjWErgVfKdKf6j4FzImJpOTppjp9a10RELAXez7TTc1MycwewA2B0dDTHxsa66qnVatHtsovVDdsffHt627pJ7jjU0+8bp+zodWML8jinahj3fdMw9z/MvcP89d/VabGI2AB8Fvi9zPxpY9ZeYFO50ms1sAZ4DHgcWFOuDDuD9pv+e0soPQJcXZbfDDzQWNfmMn018O1GiEmSBticv6ZGxDeAMeC8iJgAbqF9ddiZwIHyHvujmfmHmXk4Iu4Dvk/7dNmNmfmLsp6bgP3AEmBXZh4uD/FZYE9EfBF4CthZ6juBr0fEOO0jlk0V+pUkLYA5wyUzr+1Q3tmhNjX+VuDWDvV9wL4O9RdoX002vf4z4Jq5tk+SNHj8hL4kqTrDRZJUneEiSarOcJEkVWe4SJKqM1wkSdUZLpKk6gwXSVJ1hoskqTrDRZJUneEiSarOcJEkVWe4SJKqM1wkSdUZLpKk6gwXSVJ1hoskqTrDRZJUneEiSapuab83QMNp1fYH+/bYR2+7sm+PLQ0Lj1wkSdUZLpKk6gwXSVJ1c4ZLROyKiNci4tlG7dyIOBARR8rX5aUeEXFXRIxHxDMRcVFjmc1l/JGI2NyofzQiDpVl7oqImO0xJEmD71SOXL4GbJhW2w48nJlrgIfLfYDLgTXlthW4G9pBAdwCfAy4GLilERZ3l7FTy22Y4zEkSQNuznDJzO8Ax6eVNwK7y/Ru4KpG/Z5sexQ4JyLOBy4DDmTm8cw8ARwANpR5
Z2fmdzMzgXumravTY0iSBly3lyKPZOYxgMw8FhEfLPUVwMuNcROlNlt9okN9tsd4l4jYSvvoh5GREVqtVldNnTx5sutlF6tt6ybfnh5Z9s77p6tO+3gY933TMPc/zL3D/PVf+3Mu0aGWXdTfk8zcAewAGB0dzbGxsfe6CqD9Q6fbZRerGxqfN9m2bpI7Dp3+H306et3Yu2rDuO+bhrn/Ye4d5q//bq8We7Wc0qJ8fa3UJ4ALGuNWAq/MUV/ZoT7bY0iSBly34bIXmLriazPwQKN+fblqbD3wRjm1tR+4NCKWlzfyLwX2l3lvRsT6cpXY9dPW1ekxJEkDbs5zIBHxDWAMOC8iJmhf9XUbcF9EbAFeAq4pw/cBVwDjwE+BTwFk5vGI+ALweBn3+cycukjg07SvSFsGPFRuzPIYkqQBN2e4ZOa1M8y6pMPYBG6cYT27gF0d6k8AF3ao/59OjyFJGnx+Ql+SVJ3hIkmqznCRJFVnuEiSqjNcJEnVGS6SpOoMF0lSdYaLJKk6w0WSVJ3hIkmqznCRJFVnuEiSqjNcJEnVGS6SpOoMF0lSdYaLJKk6w0WSVJ3hIkmqznCRJFVnuEiSqjNcJEnVGS6SpOp6CpeI+NcRcTgino2Ib0TEr0bE6og4GBFHIuKbEXFGGXtmuT9e5q9qrOfmUn8+Ii5r1DeU2nhEbO9lWyVJC6frcImIFcAfA6OZeSGwBNgE3A7cmZlrgBPAlrLIFuBEZn4YuLOMIyLWluU+AmwAvhIRSyJiCfBl4HJgLXBtGStJGnC9nhZbCiyLiKXArwHHgE8C95f5u4GryvTGcp8y/5KIiFLfk5k/z8wXgXHg4nIbz8wXMvMtYE8ZK0kacF2HS2b+CPiPwEu0Q+UN4Eng9cycLMMmgBVlegXwcll2soz/QLM+bZmZ6pKkAbe02wUjYjntI4nVwOvAn9M+hTVdTi0yw7yZ6p2CLzvUiIitwFaAkZERWq3WbJs+o5MnT3a97GK1bd3k29Mjy955/3TVaR8P475vGub+h7l3mL/+uw4X4F8CL2bm3wFExLeAfw6cExFLy9HJSuCVMn4CuACYKKfR3g8cb9SnNJeZqf4OmbkD2AEwOjqaY2NjXTXUarXodtnF6obtD749vW3dJHcc6uVbYnE4et3Yu2rDuO+bhrn/Ye4d5q//Xt5zeQlYHxG/Vt47uQT4PvAIcHUZsxl4oEzvLfcp87+dmVnqm8rVZKuBNcBjwOPAmnL12Rm03/Tf28P2SpIWSNe/pmbmwYi4H/geMAk8Rfvo4UFgT0R8sdR2lkV2Al+PiHHaRyybynoOR8R9tINpErgxM38BEBE3AftpX4m2KzMPd7u9kqSF09M5kMy8BbhlWvkF2ld6TR/7M+CaGdZzK3Brh/o+YF8v2yhJWnin/wl2aZpVjfeZpmxbN/mO95/mw9HbrpzX9UuDxD//IkmqznCRJFVnuEiSqjNcJEnVGS6SpOoMF0lSdYaLJKk6w0WSVJ3hIkmqznCRJFVnuEiSqjNcJEnVGS6SpOoMF0lSdYaLJKk6w0WSVJ3hIkmqznCRJFXnvzmWhkCnf+3cNF//5tl/7Ty8PHKRJFVnuEiSquspXCLinIi4PyL+NiKei4jfiYhzI+JARBwpX5eXsRERd0XEeEQ8ExEXNdazuYw/EhGbG/WPRsShssxdERG9bK8kaWH0euTyJeB/ZOZvAv8UeA7YDjycmWuAh8t9gMuBNeW2FbgbICLOBW4BPgZcDNwyFUhlzNbGcht63F5J0gLoOlwi4mzgE8BOgMx8KzNfBzYCu8uw3cBVZXojcE+2PQqcExHnA5cBBzLzeGaeAA4AG8q8szPzu5mZwD2NdUmSBlgvRy4fAv4O+O8R8VREfDUizgJGMvMYQPn6wTJ+BfByY/mJUputPtGhLkkacL1cirwUuAj4o8w8GBFf4penwDrp9H5JdlF/94ojttI+fcbIyAitVmuWzZjZyZMnu152sdq2bvLt6ZFl77w/
TBai935+b83V23z1vxheT8P4um+ar/57CZcJYCIzD5b799MOl1cj4vzMPFZObb3WGH9BY/mVwCulPjat3ir1lR3Gv0tm7gB2AIyOjubY2FinYXNqtVp0u+xi1fxsw7Z1k9xxaDg/+rQQvR+9bmxe1z+buT7DMl/997PnUzWMr/um+eq/69Nimfm/gZcj4jdK6RLg+8BeYOqKr83AA2V6L3B9uWpsPfBGOW22H7g0IpaXN/IvBfaXeW9GxPpyldj1jXVJkgZYr7+q/BFwb0ScAbwAfIp2YN0XEVuAl4Bryth9wBXAOPDTMpbMPB4RXwAeL+M+n5nHy/Snga8By4CHyu20NNcnqCVpMekpXDLzaWC0w6xLOoxN4MYZ1rML2NWh/gRwYS/bKElaeH5CX5JUneEiSarOcJEkVWe4SJKqM1wkSdUZLpKk6gwXSVJ1hoskqTrDRZJUneEiSarOcJEkVWe4SJKqM1wkSdUZLpKk6gwXSVJ1hoskqTrDRZJUneEiSarOcJEkVWe4SJKqM1wkSdUZLpKk6gwXSVJ1PYdLRCyJiKci4q/K/dURcTAijkTENyPijFI/s9wfL/NXNdZxc6k/HxGXNeobSm08Irb3uq2SpIVR48jlM8Bzjfu3A3dm5hrgBLCl1LcAJzLzw8CdZRwRsRbYBHwE2AB8pQTWEuDLwOXAWuDaMlaSNOB6CpeIWAlcCXy13A/gk8D9Zchu4KoyvbHcp8y/pIzfCOzJzJ9n5ovAOHBxuY1n5guZ+Rawp4yVJA24pT0u/5+Afwu8r9z/APB6Zk6W+xPAijK9AngZIDMnI+KNMn4F8Ghjnc1lXp5W/1injYiIrcBWgJGREVqtVlfNnDx5sutle7Vt3eTcg+bZyLLB2I5+WIje+/W9BXP3Nl/997PnU9XP1/0gmK/+uw6XiPhd4LXMfDIixqbKHYbmHPNmqnc6qsoONTJzB7ADYHR0NMfGxjoNm1Or1aLbZXt1w/YH+/K4TdvWTXLHoV5/31icFqL3o9eNzev6ZzPX99d89d/Pnk9VP1/3g2C++u/lu+njwO9FxBXArwJn0z6SOScilpajl5XAK2X8BHABMBERS4H3A8cb9SnNZWaqS5IGWNfvuWTmzZm5MjNX0X5D/tuZeR3wCHB1GbYZeKBM7y33KfO/nZlZ6pvK1WSrgTXAY8DjwJpy9dkZ5TH2dru9kqSFMx/nAT4L7ImILwJPATtLfSfw9YgYp33EsgkgMw9HxH3A94FJ4MbM/AVARNwE7AeWALsy8/A8bK8kqbIq4ZKZLaBVpl+gfaXX9DE/A66ZYflbgVs71PcB+2ps46k49KM3BuK9D0la7PyEviSpOsNFklSd4SJJqs5wkSRVZ7hIkqozXCRJ1RkukqTqhvMPSUk67a06xc+sbVs3WfXzbUdvu7LauhYzj1wkSdUZLpKk6jwtJi2QUz1NI50OPHKRJFVnuEiSqjNcJEnVGS6SpOoMF0lSdYaLJKk6w0WSVJ3hIkmqznCRJFVnuEiSqjNcJEnVdR0uEXFBRDwSEc9FxOGI+EypnxsRByLiSPm6vNQjIu6KiPGIeCYiLmqsa3MZfyQiNjfqH42IQ2WZuyIiemlWkrQwejlymQS2ZeY/AdYDN0bEWmA78HBmrgEeLvcBLgfWlNtW4G5ohxFwC/Ax4GLglqlAKmO2Npbb0MP2SpIWSNfhkpnHMvN7ZfpN4DlgBbAR2F2G7QauKtMbgXuy7VHgnIg4H7gMOJCZxzPzBHAA2FDmnZ2Z383MBO5prEuSNMCq/Mn9iFgF/DZwEBjJzGPQDqCI+GAZtgJ4ubHYRKnNVp/oUO/0+FtpH+EwMjJCq9Xqqo+RZe3/Sjeshrn/Ye4d5q//bl+LNZxqP7V772fP3Th58uS8bHPP4RIRvw78BfAnmfn3s7wt0mlGdlF/dzFzB7ADYHR0NMfGxubY6s7+870PcMeh4f0XN9vWTQ5t/8PcO8xf/0evG6u+zlN1qv+6uHbv/ey5
G61Wi25/Zs6mp6vFIuJXaAfLvZn5rVJ+tZzSonx9rdQngAsai68EXpmjvrJDXZI04Hq5WiyAncBzmflnjVl7gakrvjYDDzTq15erxtYDb5TTZ/uBSyNieXkj/1Jgf5n3ZkSsL491fWNdkqQB1sux4MeBfwUcioinS+3fAbcB90XEFuAl4Joybx9wBTAO/BT4FEBmHo+ILwCPl3Gfz8zjZfrTwNeAZcBD5SZJGnBdh0tm/k86vy8CcEmH8QncOMO6dgG7OtSfAC7sdhslSf3hJ/QlSdUZLpKk6gwXSVJ1hoskqTrDRZJUneEiSapueP/ehaR5t+oU/wSLTj8euUiSqjNcJEnVGS6SpOoMF0lSdYaLJKk6w0WSVJ3hIkmqznCRJFVnuEiSqjNcJEnVGS6SpOoMF0lSdYaLJKk6w0WSVJ3hIkmqbuD/n0tEbAC+BCwBvpqZt/V5kyRpRv38HzZHb7uyb4893UAfuUTEEuDLwOXAWuDaiFjb362SJM1loMMFuBgYz8wXMvMtYA+wsc/bJEmaw6CHywrg5cb9iVKTJA2wQX/PJTrU8l2DIrYCW8vdkxHxfJePdx7w4y6XXfT+eIj7H+beYbj7P516j9u7Wmyq/39Uc1sGPVwmgAsa91cCr0wflJk7gB29PlhEPJGZo72uZ7Ea5v6HuXcY7v6HuXeYv/4H/bTY48CaiFgdEWcAm4C9fd4mSdIcBvrIJTMnI+ImYD/tS5F3ZebhPm+WJGkOAx0uAJm5D9i3QA/X86m1RW6Y+x/m3mG4+x/m3mGe+o/Md70/LklSTwb9PRdJ0iJkuBQRsSEino+I8YjY3u/tqSUijkbEoYh4OiKeKLVzI+JARBwpX5eXekTEXeU5eCYiLmqsZ3MZfyQiNvern7lExK6IeC0inm3UqvUbER8tz+d4WbbT5fJ9MUPvfxoRPyr7/+mIuKIx7+bSx/MRcVmj3vG1UC6sOViek2+Wi2wGQkRcEBGPRMRzEXE4Ij5T6sOy72fqv3/7PzOH/kb7YoEfAh8CzgD+Bljb7+2q1NtR4LxptX8PbC/T24Hby/QVwEO0P1+0HjhY6ucCL5Svy8v08n73NkO/nwAuAp6dj36Bx4DfKcs8BFze757n6P1PgX/TYeza8n1+JrC6fP8vme21ANwHbCrT/xX4dL97bvRzPnBRmX4f8IPS47Ds+5n679v+98ilbdj+zMxGYHeZ3g1c1ajfk22PAudExPnAZcCBzDyemSeAA8CGhd7oU5GZ3wGOTytX6bfMOzszv5vtV9g9jXX13Qy9z2QjsCczf56ZLwLjtF8HHV8L5bf0TwL3l+Wbz2PfZeaxzPxemX4TeI72X/MYln0/U/8zmff9b7i0nc5/ZiaBv46IJ6P9lwwARjLzGLS/KYEPlvpMz8Nif35q9buiTE+vD7qbyqmfXVOnhXjvvX8AeD0zJ6fVB05ErAJ+GzjIEO77af1Dn/a/4dJ2Sn9mZpH6eGZeRPsvS98YEZ+YZexMz8Pp+vy8134X4/NwN/CPgX8GHAPuKPXTsveI+HXgL4A/ycy/n21oh9rp2H/f9r/h0nZKf2ZmMcrMV8rX14C/pH3Y+2o5zKd8fa0Mn+l5WOzPT61+J8r09PrAysxXM/MXmfn/gP9Ge//De+/9x7RPHS2dVh8YEfErtH+w3puZ3yrlodn3nfrv5/43XNpOyz8zExFnRcT7pqaBS4Fnafc2dRXMZuCBMr0XuL5cSbMeeKOcStgPXBoRy8th9aWltlhU6bfMezMi1pdz0Nc31jWQpn6wFr9Pe/9Du/dNEXFmRKwG1tB+w7rja6G8z/AIcHVZvvk89l3ZHzuB5zLzzxqzhmLfz9R/X/d/v69yGJQb7atHfkD7SonP9Xt7KvX0IdpXe/wNcHiqL9rnTx8GjpSv55Z60P7nbD8EDgGjjXX9Ae03/caBT/W7t1l6/gbtw///S/u3sC01+wVGywv0h8B/oXwQeRBuM/T+9dLbM+UHyvmN8Z8rfTxP
48qnmV4L5fvpsfKc/DlwZr97bmzbv6B9muYZ4Olyu2KI9v1M/fdt//sJfUlSdZ4WkyRVZ7hIkqozXCRJ1RkukqTqDBdJUnWGiySpOsNFklSd4SJJqu7/A41VPraD34arAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x11cb53d50>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Distribution of the raw target. Labels make the figure readable on its own,\n",
    "# and plt.show() keeps the Axes repr out of the cell output.\n",
    "ax = dataset['Purchase'].hist()\n",
    "ax.set(title='Distribution of Purchase', xlabel='Purchase', ylabel='Number of rows')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add the natural log of Purchase as a new column; the models below are fit\n",
    "# on this column and their predictions are mapped back with np.exp.\n",
    "sample_df = sample_df.assign(log_purchase=np.log(sample_df['Purchase']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYcAAAD8CAYAAACcjGjIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAFAhJREFUeJzt3X+s5XWd3/HnS0bXka4LitxQhnZodmJlmYAwwWlNzF3ZwoBGaLMkELoMhnYag1u3mWQ79h+yuiaa1Lo1cU0mMnXY7spSdw1EUZwgt5smioC/AJEwi1Ruoc5uB1lHs7rXffeP+xk93M+5c++d++N8Cc9HcnK+3/f3c77f9zk3c1/z/XHuN1WFJEmjXjbpBiRJw2M4SJI6hoMkqWM4SJI6hoMkqWM4SJI6hoMkqWM4SJI6hoMkqbNp0g2crDPOOKO2bt066TYA+NGPfsSpp5466TYWNfT+YPg92t/q2N/qrFV/Dz300F9X1euWNbiqXpSPiy++uIbivvvum3QLJzT0/qqG36P9rY79rc5a9Qc8WMv8HethJUlSx3CQJHUMB0lSx3CQJHUMB0lSx3CQJHUMB0lSx3CQJHUMB0lS50X75zMkaeu+z63JevZun+PGFa7rqQ++bU22PVTuOUiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOssKhyRPJXk4yTeSPNhqr0lyKMkT7fn0Vk+SjyY5nORbSS4aWc/uNv6JJLtH6he39R9ur81av1FJ0vKtZM/h16vqwqra0eb3AfdW1Tbg3jYPcAWwrT32AB+H+TABbgHeBFwC3HI8UNqYPSOv23XS70iStGqrOax0FXCwTR8Erh6p31bzvgKcluQs4HLgUFUdrarngEPArrbs1VX15aoq4LaRdUmSJmC54VDAF5M8lGRPq01V1bMA7fnMVj8beHrktbOtdqL67Ji6JGlClns/hzdX1TNJzgQOJfnOCcaOO19QJ1HvVzwfTHsApqammJmZOWHTG+XYsWOD6WWcofcHw+/R/lZnvfrbu31uTdYztXnl69rIz3sSP99lhUNVPdOejyT5DPPnDL6f5KyqerYdGjrShs8C54y8fAvwTKtPL6jPtPqWMePH9bEf2A+wY8eOmp6eHjdsw83MzDCUXsYZen8w/B7tb3XWq7+V3qBnMXu3z/Hhh1d277Onrp9ek20vxyR+vkseVkpyapJfPj4NXAY8AtwFHL/iaDdwZ5u+C7ihXbW0E3i+HXa6B7gsyentRPRlwD1t2Q+T7GxXKd0wsi5J0gQsJyqngM+0q0s3AX9SVV9I8gBwR5KbgO8B17TxdwNXAoeBHwPvBKiqo0neDzzQxr2vqo626XcBnwQ2A59vD0nShCwZDlX1JHDBmPr/Ay4dUy/g5kXWdQA4MKb+IHD+MvqVJG0AvyEtSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeosOxySnJLk60k+2+bPTXJ/kieS/GmSV7T6L7X5w2351pF1vLfVH09y+Uh9V6sdTrJv7d6eJOlkrGTP4T3AYyPzHwI+UlXbgOeAm1r9JuC5qvpV4CNtHEnOA64Ffg3YBfxhC5xTgI8BVwDnAde1sZKkCVlWOCTZArwN+ESbD/BW4NNtyEHg6jZ9VZunLb+0jb8KuL2qflJV3wUOA5e0x+GqerKqfgrc3sZKkiZkuXsOfwD8LvD3bf61wA+qaq7NzwJnt+mzgacB2vLn2/if1xe8ZrG6JGlCNi01IMnbgSNV9VCS6ePlMUNriWWL1ccFVI2pkWQPsAdgamqKmZmZxRvfQMeO
HRtML+MMvT8Yfo/2tzrr1d/e7XNLD1qGqc0rX9dGft6T+PkuGQ7Am4F3JLkSeCXwaub3JE5LsqntHWwBnmnjZ4FzgNkkm4BfAY6O1I8bfc1i9Reoqv3AfoAdO3bU9PT0MtpffzMzMwyll3GG3h8Mv0f7W5316u/GfZ9bk/Xs3T7Hhx9ezq/DX3jq+uk12fZyTOLnu+Rhpap6b1VtqaqtzJ9Q/lJVXQ/cB/xmG7YbuLNN39Xmacu/VFXV6te2q5nOBbYBXwUeALa1q59e0bZx15q8O0nSSVlZVL7QfwRuT/L7wNeBW1v9VuCPkhxmfo/hWoCqejTJHcC3gTng5qr6GUCSdwP3AKcAB6rq0VX0JUlapRWFQ1XNADNt+knmrzRaOOZvgWsWef0HgA+Mqd8N3L2SXiRJ68dvSEuSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKmzmjvBSdJL1tY1un/1cuzdPvfz+2U/9cG3bcg23XOQJHUMB0lSx3CQJHUMB0lSx3CQJHUMB0lSx3CQJHUMB0lSx3CQJHUMB0lSx3CQJHUMB0lSZ8lwSPLKJF9N8s0kjyb5vVY/N8n9SZ5I8qdJXtHqv9TmD7flW0fW9d5WfzzJ5SP1Xa12OMm+tX+bkqSVWM6ew0+At1bVBcCFwK4kO4EPAR+pqm3Ac8BNbfxNwHNV9avAR9o4kpwHXAv8GrAL+MMkpyQ5BfgYcAVwHnBdGytJmpAlw6HmHWuzL2+PAt4KfLrVDwJXt+mr2jxt+aVJ0uq3V9VPquq7wGHgkvY4XFVPVtVPgdvbWEnShCzrnEP7H/43gCPAIeAvgR9U1VwbMguc3abPBp4GaMufB147Wl/wmsXqkqQJWdbNfqrqZ8CFSU4DPgO8Ydyw9pxFli1WHxdQNaZGkj3AHoCpqSlmZmZO3PgGOXbs2GB6GWfo/cHwe7S/1Vmv/vZun1t60DJMbV67da2H0f426ue8ojvBVdUPkswAO4HTkmxqewdbgGfasFngHGA2ySbgV4CjI/XjRl+zWH3h9vcD+wF27NhR09PTK2l/3czMzDCUXsYZen8w/B7tb3XWq78b1+hubHu3z/Hhh4d7Y8zR/p66fnpDtrmcq5Ve1/YYSLIZ+A3gMeA+4DfbsN3AnW36rjZPW/6lqqpWv7ZdzXQusA34KvAAsK1d/fQK5k9a37UWb06SdHKWE5VnAQfbVUUvA+6oqs8m+TZwe5LfB74O3NrG3wr8UZLDzO8xXAtQVY8muQP4NjAH3NwOV5Hk3cA9wCnAgap6dM3eoSRpxZYMh6r6FvDGMfUnmb/SaGH9b4FrFlnXB4APjKnfDdy9jH4lSRvAb0hLkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySpYzhIkjqGgySps2Q4JDknyX1JHkvyaJL3tPprkhxK8kR7Pr3Vk+SjSQ4n+VaSi0bWtbuNfyLJ7pH6xUkebq/5aJKsx5uVJC3PcvYc5oC9VfUGYCdwc5LzgH3AvVW1Dbi3zQNcAWxrjz3Ax2E+TIBbgDcBlwC3HA+UNmbPyOt2rf6tSZJO1pLhUFXPVtXX2vQPgceAs4GrgINt2EHg6jZ9FXBbzfsKcFqSs4DLgUNVdbSqngMOAbvasldX1ZerqoDbRtYlSZqATSsZnGQr8EbgfmCqqp6F+QBJcmYbdjbw9MjLZlvtRPXZMfVx29/D/B4GU1NTzMzMrKT9dXPs2LHB9DLO0PuD4fdof6uzXv3t3T63JuuZ2rx261oPo/1t1M952eGQ5B8Afwb8TlX9zQlOC4xbUCdR74tV+4H9ADt27Kjp6ekl
ut4YMzMzDKWXcYbeHwy/R/tbnfXq78Z9n1uT9ezdPseHH17R/5U31Gh/T10/vSHbXNbVSkleznww/HFV/Xkrf78dEqI9H2n1WeCckZdvAZ5Zor5lTF2SNCHLuVopwK3AY1X1X0YW3QUcv+JoN3DnSP2GdtXSTuD5dvjpHuCyJKe3E9GXAfe0ZT9MsrNt64aRdUmSJmA5+1FvBn4LeDjJN1rtPwEfBO5IchPwPeCatuxu4ErgMPBj4J0AVXU0yfuBB9q491XV0Tb9LuCTwGbg8+0hSZqQJcOhqv4X488LAFw6ZnwBNy+yrgPAgTH1B4Hzl+pFkrQx/Ia0JKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKljOEiSOoaDJKmznNuEStIJbd33uRMu37t9jhuXGKNhcc9BktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktRZMhySHEhyJMkjI7XXJDmU5In2fHqrJ8lHkxxO8q0kF428Zncb/0SS3SP1i5M83F7z0SRZ6zcpSVqZ5ew5fBLYtaC2D7i3qrYB97Z5gCuAbe2xB/g4zIcJcAvwJuAS4JbjgdLG7Bl53cJtSZI22JLhUFV/ARxdUL4KONimDwJXj9Rvq3lfAU5LchZwOXCoqo5W1XPAIWBXW/bqqvpyVRVw28i6JEkTcrLnHKaq6lmA9nxmq58NPD0ybrbVTlSfHVOXJE3QWt/PYdz5gjqJ+viVJ3uYPwTF1NQUMzMzJ9Hi2jt27Nhgehln6P3B8Hu0vxPbu33uhMunNi89ZpJeTP1t1M/5ZMPh+0nOqqpn26GhI60+C5wzMm4L8EyrTy+oz7T6ljHjx6qq/cB+gB07dtT09PRiQzfUzMwMQ+llnKH3B8Pv0f5ObKkb+ezdPseHHx7uvcVeTP09df30hmzzZA8r3QUcv+JoN3DnSP2GdtXSTuD5dtjpHuCyJKe3E9GXAfe0ZT9MsrNdpXTDyLokSROyZFQm+RTz/+s/I8ks81cdfRC4I8lNwPeAa9rwu4ErgcPAj4F3AlTV0STvBx5o495XVcdPcr+L+SuiNgOfbw9J0gQtGQ5Vdd0iiy4dM7aAmxdZzwHgwJj6g8D5S/UhSdo4fkNaktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQxHCRJneHe+kjSimxd4m5s0kq45yBJ6hgOkqSO4SBJ6hgOkqSO4SBJ6hgOkqSO4SBJ6hgOkqSO4SBJ6hgOkqSO4SBJ6hgOkqSOf3hPWmOT+AN4e7fP4T9nrSX3HCRJncGEQ5JdSR5PcjjJvkn3I0kvZYMIhySnAB8DrgDOA65Lct5ku5Kkl66hHKS8BDhcVU8CJLkduAr49kS70ovaWh7737t9jhu9mY5eQoYSDmcDT4/MzwJvWq+NrfUJw6H/4ti7fY7pCW17uZ/10D9D6aUmVTXpHkhyDXB5Vf2bNv9bwCVV9dsLxu0B9rTZ1wOPb2ijizsD+OtJN3ECQ+8Pht+j/a2O/a3OWvX3j6vqdcsZOJQ9h1ngnJH5LcAzCwdV1X5g/0Y1tVxJHqyqHZPuYzFD7w+G36P9rY79rc4k+hvECWngAWBbknOTvAK4Frhrwj1J0kvWIPYcqmouybuBe4BTgANV9eiE25Kkl6xBhANAVd0N3D3pPk7S4A51LTD0/mD4Pdrf6tjf6mx4f4M4IS1JGpahnHOQJA2I4bAKSV6Z5KtJvpnk0SS/N+mexklySpKvJ/nspHtZKMlTSR5O8o0kD066n4WSnJbk00m+k+SxJP9s0j0dl+T17XM7/vibJL8z6b5GJfkP7d/GI0k+leSVk+5pVJL3tN4eHcpnl+RAkiNJHhmpvSbJoSRPtOfT17sPw2F1fgK8taouAC4EdiXZOeGexnkP
8NikmziBX6+qCwd6KeF/Bb5QVf8UuIABfY5V9Xj73C4ELgZ+DHxmwm39XJKzgX8P7Kiq85m/2OTayXb1C0nOB/4t83+h4QLg7Um2TbYrAD4J7FpQ2wfcW1XbgHvb/LoyHFah5h1rsy9vj0GdxEmyBXgb8IlJ9/Jik+TVwFuAWwGq6qdV9YPJdrWoS4G/rKr/PelGFtgEbE6yCXgVY76/NEFvAL5SVT+uqjngfwL/csI9UVV/ARxdUL4KONimDwJXr3cfhsMqtUM23wCOAIeq6v5J97TAHwC/C/z9pBtZRAFfTPJQ+wb8kPwT4K+A/9YOy30iyamTbmoR1wKfmnQTo6rq/wD/Gfge8CzwfFV9cbJdvcAjwFuSvDbJq4AreeGXcYdkqqqeBWjPZ673Bg2HVaqqn7Xd+i3AJW1XdRCSvB04UlUPTbqXE3hzVV3E/F/kvTnJWybd0IhNwEXAx6vqjcCP2IDd+ZVqXxx9B/A/Jt3LqHZc/CrgXOAfAqcm+deT7eoXquox4EPAIeALwDeBuYk2NSCGwxpphxtm6I8VTtKbgXckeQq4HXhrkv8+2ZZeqKqeac9HmD9efslkO3qBWWB2ZG/w08yHxdBcAXytqr4/6UYW+A3gu1X1V1X1d8CfA/98wj29QFXdWlUXVdVbmD+U88Ske1rE95OcBdCej6z3Bg2HVUjyuiSntenNzP9j+M5ku/qFqnpvVW2pqq3MH3b4UlUN5n9uSU5N8svHp4HLmN/VH4Sq+r/A00le30qXMsw/I38dAzuk1HwP2JnkVUnC/Oc3mBP6AEnObM//CPhXDPNzhPk/J7S7Te8G7lzvDQ7mG9IvUmcBB9vNil4G3FFVg7tcdMCmgM/M/95gE/AnVfWFybbU+W3gj9uhmyeBd064nxdox8r/BfDvJt3LQlV1f5JPA19j/nDN1xneN5H/LMlrgb8Dbq6q5ybdUJJPAdPAGUlmgVuADwJ3JLmJ+dC9Zt378BvSkqSFPKwkSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkzv8HBlFHWWJYZdIAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Distribution of the log-transformed target in the sample, with a title and\n",
    "# axis labels so the figure can be read without the surrounding code.\n",
    "ax = sample_df['log_purchase'].hist()\n",
    "ax.set(title='Distribution of log(Purchase)', xlabel='log(Purchase)', ylabel='Number of rows')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target: the log-transformed purchase amount.\n",
    "y = sample_df['log_purchase']\n",
    "# Features: every column except the raw target and its log transform.\n",
    "X = sample_df.drop(['Purchase', 'log_purchase'], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import division\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "from sklearn.metrics import mean_squared_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out 20% of the rows for validation. random_state makes the split\n",
    "# repeatable, so the train and validation rows are the same on every run.\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train size is: (88011, 18)\n",
      "Test size is: (22003, 18)\n"
     ]
    }
   ],
   "source": [
    "# Report the (rows, columns) shape of each split.\n",
    "for template, frame in ((\"Train size is: {}\", X_train), (\"Test size is: {}\", X_test)):\n",
    "    print(template.format(frame.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learn the scaling statistics from the training features only, then apply\n",
    "# the same transformation to both splits.\n",
    "# NOTE(review): X_train_sc / X_test_sc are not read by any later cell in this\n",
    "# notebook; both models are fit on the unscaled X_train. Confirm the intent.\n",
    "sc = StandardScaler().fit(X_train)\n",
    "\n",
    "X_train_sc = sc.transform(X_train)\n",
    "X_test_sc = sc.transform(X_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Random Forest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "118.10078716278076\n",
      "Train 0.7788117394416607\n",
      "Validation 0.5745580148161614\n",
      "3224.150338493422\n"
     ]
    }
   ],
   "source": [
    "def share_within_tolerance(y_true_log, y_pred_log, tolerance=0.25):\n",
    "    \"\"\"Return the fraction of rows predicted within a relative tolerance.\n",
    "\n",
    "    Both arguments hold log-purchase values; they are mapped back with np.exp\n",
    "    and a row counts as a hit when |pred - true| / true < tolerance.\n",
    "    Inputs are flattened first, so (n,) and (n, 1) arrays are both accepted.\n",
    "    \"\"\"\n",
    "    actual = np.exp(np.ravel(y_true_log))\n",
    "    predicted = np.exp(np.ravel(y_pred_log))\n",
    "    return np.mean(np.abs(predicted - actual) / actual < tolerance)\n",
    "\n",
    "\n",
    "start = time.time()\n",
    "\n",
    "np.random.seed(7)\n",
    "model = RandomForestRegressor(n_estimators=500, max_depth=60, random_state=7)\n",
    "model.fit(X_train, y_train)\n",
    "\n",
    "end = time.time()\n",
    "# Same wording as the timing line printed by the deep-learning cell.\n",
    "print(\"time elapsed:%.2f\" % (end - start))\n",
    "\n",
    "# Share of rows whose predicted purchase is within 25% of the true purchase.\n",
    "y_pred_train = model.predict(X_train)\n",
    "print('Train', share_within_tolerance(y_train, y_pred_train))\n",
    "\n",
    "y_pred_test = model.predict(X_test)\n",
    "print('Validation', share_within_tolerance(y_test, y_pred_test))\n",
    "\n",
    "# RMSE on the original purchase scale (both sides are exponentiated first).\n",
    "print('RMSE', np.sqrt(mean_squared_error(np.exp(y_test), np.exp(y_pred_test))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Deep Learning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Building model...\n",
      "Train on 440054 samples, validate on 110014 samples\n",
      "Epoch 1/50\n",
      "440054/440054 [==============================] - 8s 19us/step - loss: 0.8492 - val_loss: 0.3168\n",
      "Epoch 2/50\n",
      "440054/440054 [==============================] - 9s 21us/step - loss: 0.2703 - val_loss: 0.2410\n",
      "Epoch 3/50\n",
      "440054/440054 [==============================] - 13s 29us/step - loss: 0.2233 - val_loss: 0.2150\n",
      "Epoch 4/50\n",
      "440054/440054 [==============================] - 9s 21us/step - loss: 0.2072 - val_loss: 0.1798\n",
      "Epoch 5/50\n",
      "440054/440054 [==============================] - 9s 20us/step - loss: 0.1958 - val_loss: 0.2019\n",
      "Epoch 6/50\n",
      "440054/440054 [==============================] - 8s 19us/step - loss: 0.1941 - val_loss: 0.1639\n",
      "Epoch 7/50\n",
      "440054/440054 [==============================] - 9s 20us/step - loss: 0.1826 - val_loss: 0.1714\n",
      "Epoch 8/50\n",
      "440054/440054 [==============================] - 9s 21us/step - loss: 0.1780 - val_loss: 0.1649\n",
      "Epoch 9/50\n",
      "440054/440054 [==============================] - 10s 24us/step - loss: 0.1829 - val_loss: 0.1796\n",
      "Epoch 10/50\n",
      "440054/440054 [==============================] - 10s 23us/step - loss: 0.1792 - val_loss: 0.1563\n",
      "Epoch 11/50\n",
      "440054/440054 [==============================] - 10s 23us/step - loss: 0.1791 - val_loss: 0.1774\n",
      "Epoch 12/50\n",
      "440054/440054 [==============================] - 10s 22us/step - loss: 0.1729 - val_loss: 0.1681\n",
      "Epoch 13/50\n",
      "440054/440054 [==============================] - 10s 23us/step - loss: 0.1742 - val_loss: 0.1865\n",
      "Epoch 14/50\n",
      "440054/440054 [==============================] - 9s 21us/step - loss: 0.1716 - val_loss: 0.1545\n",
      "Epoch 15/50\n",
      "440054/440054 [==============================] - 9s 21us/step - loss: 0.1693 - val_loss: 0.2089\n",
      "Epoch 16/50\n",
      "440054/440054 [==============================] - 9s 21us/step - loss: 0.1740 - val_loss: 0.1736\n",
      "Epoch 17/50\n",
      "440054/440054 [==============================] - 9s 21us/step - loss: 0.1744 - val_loss: 0.1595\n",
      "Epoch 18/50\n",
      "440054/440054 [==============================] - 9s 21us/step - loss: 0.1703 - val_loss: 0.1547\n",
      "Epoch 19/50\n",
      "440054/440054 [==============================] - 9s 21us/step - loss: 0.1636 - val_loss: 0.1582\n",
      "Epoch 20/50\n",
      "440054/440054 [==============================] - 9s 21us/step - loss: 0.1622 - val_loss: 0.1630\n",
      "Epoch 21/50\n",
      "440054/440054 [==============================] - 9s 20us/step - loss: 0.1607 - val_loss: 0.1532\n",
      "Epoch 22/50\n",
      "440054/440054 [==============================] - 9s 21us/step - loss: 0.1644 - val_loss: 0.1520\n",
      "Epoch 23/50\n",
      "440054/440054 [==============================] - 9s 21us/step - loss: 0.1597 - val_loss: 0.1490\n",
      "Epoch 24/50\n",
      "440054/440054 [==============================] - 9s 20us/step - loss: 0.1603 - val_loss: 0.1547\n",
      "Epoch 25/50\n",
      "440054/440054 [==============================] - 9s 20us/step - loss: 0.1592 - val_loss: 0.1720\n",
      "Epoch 26/50\n",
      "440054/440054 [==============================] - 10s 22us/step - loss: 0.1624 - val_loss: 0.1520\n",
      "Epoch 27/50\n",
      "440054/440054 [==============================] - 11s 24us/step - loss: 0.1584 - val_loss: 0.1492\n",
      "Epoch 28/50\n",
      "440054/440054 [==============================] - 11s 24us/step - loss: 0.1601 - val_loss: 0.1537\n",
      "Epoch 29/50\n",
      "440054/440054 [==============================] - 11s 24us/step - loss: 0.1580 - val_loss: 0.1528\n",
      "Epoch 30/50\n",
      "440054/440054 [==============================] - 10s 24us/step - loss: 0.1564 - val_loss: 0.1921\n",
      "Epoch 31/50\n",
      "440054/440054 [==============================] - 10s 23us/step - loss: 0.1582 - val_loss: 0.1497\n",
      "Epoch 32/50\n",
      "440054/440054 [==============================] - 9s 20us/step - loss: 0.1590 - val_loss: 0.1693\n",
      "Epoch 33/50\n",
      "440054/440054 [==============================] - 8s 18us/step - loss: 0.1582 - val_loss: 0.1513\n",
      "Epoch 34/50\n",
      "440054/440054 [==============================] - 8s 19us/step - loss: 0.1584 - val_loss: 0.1514\n",
      "Epoch 35/50\n",
      "440054/440054 [==============================] - 8s 18us/step - loss: 0.1563 - val_loss: 0.2626\n",
      "Epoch 36/50\n",
      "440054/440054 [==============================] - 8s 18us/step - loss: 0.1572 - val_loss: 0.1606\n",
      "Epoch 37/50\n",
      "440054/440054 [==============================] - 8s 18us/step - loss: 0.1545 - val_loss: 0.1505\n",
      "Epoch 38/50\n",
      "440054/440054 [==============================] - 8s 18us/step - loss: 0.1552 - val_loss: 0.1544\n",
      "Epoch 39/50\n",
      "440054/440054 [==============================] - 8s 19us/step - loss: 0.1534 - val_loss: 0.1470\n",
      "Epoch 40/50\n",
      "440054/440054 [==============================] - 9s 20us/step - loss: 0.1524 - val_loss: 0.1544\n",
      "Epoch 41/50\n",
      "440054/440054 [==============================] - 9s 19us/step - loss: 0.1527 - val_loss: 0.1462\n",
      "Epoch 42/50\n",
      "440054/440054 [==============================] - 8s 19us/step - loss: 0.1556 - val_loss: 0.1499\n",
      "Epoch 43/50\n",
      "440054/440054 [==============================] - 8s 19us/step - loss: 0.1518 - val_loss: 0.1556\n",
      "Epoch 44/50\n",
      "440054/440054 [==============================] - 8s 18us/step - loss: 0.1509 - val_loss: 0.1500\n",
      "Epoch 45/50\n",
      "440054/440054 [==============================] - 8s 19us/step - loss: 0.1518 - val_loss: 0.1472\n",
      "Epoch 46/50\n",
      "440054/440054 [==============================] - 8s 19us/step - loss: 0.1510 - val_loss: 0.1483\n",
      "Epoch 47/50\n",
      "440054/440054 [==============================] - 9s 20us/step - loss: 0.1510 - val_loss: 0.1459\n",
      "Epoch 48/50\n",
      "440054/440054 [==============================] - 9s 20us/step - loss: 0.1524 - val_loss: 0.1483\n",
      "Epoch 49/50\n",
      "440054/440054 [==============================] - 8s 19us/step - loss: 0.1513 - val_loss: 0.1460\n",
      "Epoch 50/50\n",
      "440054/440054 [==============================] - 8s 19us/step - loss: 0.1537 - val_loss: 0.1485\n",
      "Done training...\n",
      "time elapsed:454.23\n"
     ]
    }
   ],
   "source": [
    "from keras.models import Sequential\n",
    "from keras.layers import Dense\n",
    "from keras.layers import Dropout\n",
    "from keras import regularizers\n",
    "from keras.optimizers import Adam\n",
    "import keras.backend as K\n",
    "# Hyperparameters.\n",
    "seed = 7\n",
    "learning_rate = 0.001\n",
    "training_epochs = 50\n",
    "batch = 128\n",
    "dims = X_train.shape[1]  # number of input features\n",
    "\n",
    "# NOTE(review): display_step, factor, beta_1, beta_2, epsilon and decay used\n",
    "# to be assigned here but were never read; Adam only ever received lr. They\n",
    "# are dropped so the cell no longer suggests the optimizer runs with\n",
    "# decay=0.001. Training behaviour is unchanged.\n",
    "# NOTE(review): the network is fit on X_train as-is, not on the standardized\n",
    "# X_train_sc computed earlier -- confirm whether that is intended.\n",
    "# NOTE(review): the output layer uses relu; a linear output is the usual\n",
    "# choice for a regression target -- confirm before changing.\n",
    "\n",
    "# Seed NumPy once (it was previously seeded twice with the same value) and\n",
    "# reset the Keras backend state before building a new model.\n",
    "np.random.seed(seed)\n",
    "K.clear_session()\n",
    "\n",
    "# Fully connected network: 225 -> 115 -> 55 -> 27 -> 1, with L2 weight\n",
    "# regularization on the first layer only.\n",
    "model = Sequential()\n",
    "model.add(Dense(225, input_shape=(dims,), activation='relu',\n",
    "                kernel_regularizer=regularizers.l2(0.01)))\n",
    "model.add(Dense(115, activation='relu'))\n",
    "model.add(Dense(55, activation='relu'))\n",
    "model.add(Dense(27, activation='relu'))\n",
    "model.add(Dense(1, activation='relu'))\n",
    "model.compile(optimizer=Adam(lr=learning_rate), loss='mean_squared_error')\n",
    "\n",
    "start = time.time()\n",
    "print(\"Building model...\" )\n",
    "# The held-out split is passed as validation data, so val_loss is reported\n",
    "# after every epoch and stored in history.\n",
    "history = model.fit(X_train, y_train,\n",
    "                    validation_data=(X_test, y_test),\n",
    "                    epochs=training_epochs,\n",
    "                    batch_size=batch,\n",
    "                    verbose=1)\n",
    "end = time.time()\n",
    "print(\"Done training...\")\n",
    "print(\"time elapsed:%.2f\"%(end-start))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the trained model to purchase_predictor.h5 in the working directory.\n",
    "model.save('purchase_predictor.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this line was an unfinished assignment (`model_dl = `), which\n",
    "# is a SyntaxError and stops a top-to-bottom run. It is completed on the\n",
    "# assumption that the intent was to reload the network saved to\n",
    "# purchase_predictor.h5 -- confirm.\n",
    "from keras.models import load_model\n",
    "\n",
    "model_dl = load_model('purchase_predictor.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model.predict returns a 2-D (n, 1) array for the single-unit output layer.\n",
    "# Flatten it and take the target as a plain 1-D array so the relative error\n",
    "# is computed row by row instead of being broadcast against the 1-D target.\n",
    "y_pred_train = model.predict(X_train).ravel()\n",
    "actual_train = np.exp(np.asarray(y_train))\n",
    "print('Train', np.sum(np.abs(np.exp(y_pred_train) - actual_train) / actual_train < 0.25) / y_train.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "37.71666666666667"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# model.predict returns a 2-D (n, 1) array for the single-unit output layer.\n",
    "# Flatten it and take the target as a plain 1-D array so the relative error\n",
    "# is computed row by row instead of being broadcast against the 1-D target.\n",
    "y_pred_test = model.predict(X_test).ravel()\n",
    "actual_test = np.exp(np.asarray(y_test))\n",
    "print('Validation', np.sum(np.abs(np.exp(y_pred_test) - actual_test) / actual_test < 0.25) / y_test.shape[0])\n",
    "\n",
    "# RMSE on the original purchase scale.\n",
    "print(np.sqrt(mean_squared_error(actual_test, np.exp(y_pred_test))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "440054/440054 [==============================] - 8s 19us/step\n"
     ]
    }
   ],
   "source": [
    "# Same training-set metric, with a progress bar during prediction.\n",
    "# model.predict returns a 2-D (n, 1) array for the single-unit output layer.\n",
    "# Flatten it and take the target as a plain 1-D array so the relative error\n",
    "# is computed row by row instead of being broadcast against the 1-D target.\n",
    "y_pred_train = model.predict(X_train, verbose=1).ravel()\n",
    "actual_train = np.exp(np.asarray(y_train))\n",
    "print('Train', np.sum(np.abs(np.exp(y_pred_train) - actual_train) / actual_train < 0.25) / y_train.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'history' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-1-170f5ce87cc0>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mhistory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhistory\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'history' is not defined"
     ]
    }
   ],
   "source": [
    "# Metrics recorded by model.fit. `history` only exists after the training\n",
    "# cell has run in the current kernel session; otherwise this is a NameError.\n",
    "history.history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
